from dataset.oxford_pet import Databasket, CUB

import torchvision
import torchvision.transforms as transforms
from LogME import LogME
import torchvision.models as models

from metrics import *
from net_forward import *
from pytorchcv.model_provider import get_model as ptcv_get_model
import simclr
import pandas as pd

from lightly.transforms import *
from ssl_model import *



def correlation(te_scores, gt, args):
    """Log transferability scores and their correlation with ``gt``.

    Args:
        te_scores: One score per model. When ``args.metric`` is
            ``'nc_cls_density_cb'`` every entry is read as ``entry[0]`` and
            ``entry[1]``; each component is min-max normalised across all
            models and the two normalised values are summed into a single
            score per model. For any other metric the scores are used as-is.
        gt: Reference values stored in the ``'acc'`` column; must have the
            same length as ``te_scores``.
        args: Object providing ``metric`` (column name / mode switch) and
            ``out_file`` (an open, writable text handle).

    Side effects:
        Prints the scores and the correlation matrix, and appends both to
        ``args.out_file`` (flushed after each write).
    """

    def min_max_norm(data):
        # Compute the extremes once instead of once per element.
        if not data:
            return []
        lowest = min(data)
        span = max(data) - lowest
        if span == 0:
            # All values identical: no ranking information, map to zeros.
            return [0] * len(data)
        return [(x - lowest) / span for x in data]

    if args.metric == 'nc_cls_density_cb':
        nc = [tup[0] for tup in te_scores]
        density = [tup[1] for tup in te_scores]

        print(nc)
        print(density)

        rank_nc = min_max_norm(nc)
        rank_density = min_max_norm(density)

        # Combined score: element-wise sum of the two normalised components.
        te_scores = [a + b for a, b in zip(rank_nc, rank_density)]

    print(te_scores)
    df = pd.DataFrame({'acc': gt, args.metric: te_scores})
    torch.cuda.empty_cache()
    args.out_file.write(str(te_scores) + '\n')
    args.out_file.flush()

    # Compute the correlation matrix once and reuse it for both outputs.
    corr = df.corr()
    print(corr)
    args.out_file.write(str(corr) + '\n')
    args.out_file.flush()


# Number of classes for every dataset name understood by ``cls_num_query``.
_NUM_CLASSES_BY_DATASET = {
    'imagenet': 1000,
    'cifar10': 10,
    'cifar100': 100,
    'imagenette': 10,
    'oxfordpets': 37,
    'oxfordflowers': 102,
    'CUB': 200,
    'SUN397': 397,
    'DTD': 47,
    'food101': 101,
    'country211': 211,
    'place365': 365,
    'stanfordcars': 196,
    # NOTE(review): this entry used to be 196, which looks copy-pasted from
    # 'stanfordcars'. 101 matches the torchvision Caltech101 dataset; confirm
    # that no existing checkpoint was trained with a 196-way head.
    'caltech101': 101,
    'celeba': 40,
    'fashionmnist': 10,
    'svhn': 10,
    'fgvcaircraft': 100,
    'gtsrb': 43,
    'inaturalist': 10000,
    'renderedsst2': 2,
    'stl10': 10,
}


def cls_num_query(src):
    """Return the number of classes of the dataset called ``src``.

    Args:
        src: Dataset name (case-sensitive), e.g. ``'imagenet'`` or ``'CUB'``.

    Returns:
        The class count as an ``int``.

    Raises:
        ValueError: If ``src`` is not a known dataset name.
    """
    try:
        return _NUM_CLASSES_BY_DATASET[src]
    except KeyError:
        raise ValueError(
            f'Unknown dataset {src!r}; expected one of '
            f'{sorted(_NUM_CLASSES_BY_DATASET)}'
        ) from None

# torchvision datasets stored under ``./data/`` and selected through a
# ``split=`` keyword: dataset name -> (class name in ``torchvision.datasets``,
# train split, evaluation split, ``download`` flag).
_TV_SPLIT_DATASETS = {
    'oxfordflowers': ('Flowers102', 'train', 'val', False),
    'food101': ('Food101', 'train', 'test', True),
    'country211': ('Country211', 'train', 'valid', True),
    'place365': ('Places365', 'train-standard', 'val', True),
    'stanfordcars': ('StanfordCars', 'train', 'test', True),
    'fgvcaircraft': ('FGVCAircraft', 'train', 'test', True),
    'gtsrb': ('GTSRB', 'train', 'test', True),
    'svhn': ('SVHN', 'train', 'test', True),
    'stl10': ('STL10', 'train', 'test', True),
}


def _build_ssl_train_transform(args):
    """Return the multi-view training transform matching ``args.ssl``."""
    method = args.ssl
    size = args.input_size

    if method in ['dcl', 'simclr', 'nnclr']:
        return simclr_transform.SimCLRTransform(input_size=size)
    if method in ['barlowtwins', 'tico']:
        return byol_transform.BYOLTransform(
            view_1_transform=byol_transform.BYOLView1Transform(input_size=size, gaussian_blur=0.0),
            view_2_transform=byol_transform.BYOLView2Transform(input_size=size, gaussian_blur=0.0),
        )
    if method == 'dino':
        return dino_transform.DINOTransform(global_crop_size=size)
    if method == 'fastsiam':
        return FastSiamTransform(input_size=size)
    if method == 'moco':
        return MoCoV2Transform(input_size=size)
    if method == 'simsiam':
        return SimSiamTransform(input_size=size)
    if method == 'smog':
        return SMoGTransform(
            crop_sizes=(size, size),
            crop_counts=(1, 1),
            gaussian_blur_probs=(0.0, 0.0),
            crop_min_scales=(0.2, 0.2),
            crop_max_scales=(1.0, 1.0),
        )
    if method in ['swav', 'swavq']:
        return SwaVTransform(crop_sizes=(size, size))
    if method == 'vicreg':
        return VICRegTransform(input_size=size)
    if method == 'vicregl':
        return VICRegLTransform(global_crop_size=size, n_local_views=0)

    # Previously this fell through and failed later with UnboundLocalError.
    raise ValueError(f'No training transform is defined for SSL method {method!r}')


def build_tgt_datasets(args, ssl=False, src=False):
    """Build the train/test datasets and data loaders for one dataset.

    Args:
        args: Object providing ``tgt`` and ``src`` (dataset names),
            ``input_size``, ``crop_size`` (only read when ``ssl`` is false),
            ``bs``, ``workers`` and, when ``ssl`` is true, ``ssl``.
        ssl: If true, the training transform is the self-supervised
            multi-view transform selected by ``args.ssl``; otherwise a
            resize / random-crop / horizontal-flip pipeline is used.
        src: If true, load ``args.src`` instead of ``args.tgt``.

    Returns:
        ``(trainset, trainloader, testset, testloader)``. The train loader
        shuffles, the test loader does not.

    Raises:
        ValueError: If ``ssl`` is true and ``args.ssl`` is not supported.

    Notes:
        An unknown dataset name prints a message and terminates the process
        via ``exit()``. ``'SUN397'`` only instantiates the dataset (which
        triggers the download) and then exits, because no evaluation split is
        wired up for it.
    """
    if ssl:
        transform_train = _build_ssl_train_transform(args)
    else:
        transform_train = transforms.Compose([
            transforms.Resize([args.crop_size, args.crop_size]),
            transforms.RandomCrop([args.input_size, args.input_size]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

    # The same fixed mean/std constants are applied to every dataset.
    transform_test = transforms.Compose([
        transforms.Resize([args.input_size, args.input_size]),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    dset = args.src if src else args.tgt

    if dset == 'cifar10':
        trainset = torchvision.datasets.CIFAR10(
            root='./data', train=True, download=True, transform=transform_train)
        testset = torchvision.datasets.CIFAR10(
            root='./data', train=False, download=True, transform=transform_test)
    elif dset == 'cifar100':
        trainset = torchvision.datasets.CIFAR100(
            root='./data', train=True, download=True, transform=transform_train)
        testset = torchvision.datasets.CIFAR100(
            root='./data', train=False, download=True, transform=transform_test)
    elif dset == 'imagenette':
        train_root = './data/imagenette/train'
        test_root = './data/imagenette/val'
        trainset = torchvision.datasets.ImageFolder(train_root, transform=transform_train)
        # Bug fix: evaluation previously re-read ``train_root``, so the
        # "test" loader served the training images.
        testset = torchvision.datasets.ImageFolder(test_root, transform=transform_test)
    elif dset == 'oxfordpets':
        databasket = Databasket(train_transforms=transform_train, val_transforms=transform_test)
        trainset = databasket.train_ds
        testset = databasket.val_ds
    elif dset == 'CUB':
        trainset = CUB(root='./data/CUB', is_train=True, transform=transform_train,)
        testset = CUB(root='./data/CUB', is_train=False, transform=transform_test,)
    elif dset == 'SUN397':
        # Download-only branch: there is no test split configured, so the
        # process stops right after the dataset has been materialised.
        torchvision.datasets.SUN397(root='/data/yuhe.ding/DATA/SUN397', transform=transform_train, download=True)
        exit()
    elif dset == 'DTD':
        # Bug fix: the test split used './data/yuhe.ding/DATA/DTD' (relative)
        # while the train split used the absolute path, so the dataset was
        # fetched into two different locations.
        dtd_root = '/data/yuhe.ding/DATA/DTD'
        trainset = torchvision.datasets.DTD(root=dtd_root, split='train', transform=transform_train, download=True)
        testset = torchvision.datasets.DTD(root=dtd_root, split='val', transform=transform_test, download=True)
    elif dset in _TV_SPLIT_DATASETS:
        cls_name, train_split, test_split, download = _TV_SPLIT_DATASETS[dset]
        dataset_cls = getattr(torchvision.datasets, cls_name)
        trainset = dataset_cls(root='./data/', split=train_split, transform=transform_train, download=download)
        testset = dataset_cls(root='./data/', split=test_split, transform=transform_test, download=download)
    else:
        print(f'PLEASE CHECK THE NAME OF THE TARGET DATASET {dset}')
        exit()

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=args.bs, shuffle=True, num_workers=args.workers)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=args.bs, shuffle=False, num_workers=args.workers)

    return trainset, trainloader, testset, testloader

def build_net_imgnet_pretrain(args, ckpt_src_dir):
    """Instantiate the architecture named by ``args.net`` and locate its checkpoint.

    Args:
        args: Object providing ``net`` (architecture name) and, for DenseNet
            variants, ``src`` (source dataset name).
        ckpt_src_dir: Directory prefix (string) that the checkpoint file name
            is appended to.

    Returns:
        ``(net, ckpt_src)``. ``ckpt_src`` is ``None`` for a DenseNet whose
        source is ``'imagenet'``: those models are created with
        ``pretrained=True`` and need no separate checkpoint.

    An unknown architecture prints a message and terminates via ``exit()``.
    """
    arch = args.net

    # Architectures created with a bare constructor call:
    # net name -> (torchvision constructor name, checkpoint file name).
    bare_archs = {
        'resnet50': ('resnet50', '/resnet50-v2.pth'),
        'resnet101': ('resnet101', '/resnet101-v2.pth'),
        'resnet152': ('resnet152', '/resnet152-v2.pth'),
        'mobilenetv2': ('mobilenet_v2', '/mobilenet_v2.pth'),
        'mobilenetv3_large': ('mobilenet_v3_large', '/mobilenet_v3_large-8738ca79.pth'),
        'mobilenetv3_small': ('mobilenet_v3_small', '/mobilenet_v3_small-047dcff4.pth'),
        'swin_b': ('swin_b', '/swin_b.pth'),
        'swin_v2_b': ('swin_v2_b', '/swin_v2_b.pth'),
        'vit_b_16': ('vit_b_16', '/vit_b_16.pth'),
        'wide_resnet101_2': ('wide_resnet101_2', '/wide_resnet101_2.pth'),
        'efficientnetb0': ('efficientnet_b0', '/efficientnet_b0_rwightman-3dd342df.pth'),
        'efficientnetb1': ('efficientnet_b1', '/efficientnet_b1_rwightman-533bc792.pth'),
        'efficientnetb2': ('efficientnet_b2', '/efficientnet_b2_rwightman-bcdf34b7.pth'),
        'efficientnetb3': ('efficientnet_b3', '/efficientnet_b3_rwightman-cf984f9c.pth'),
        'vgg16': ('vgg16', '/vgg16-397923af.pth'),
        'vgg19': ('vgg19', '/vgg19-dcbb9e9d.pth'),
    }

    if arch in bare_archs:
        ctor_name, ckpt_file = bare_archs[arch]
        model = getattr(models, ctor_name)()
        return model, ckpt_src_dir + ckpt_file

    if arch in ('densenet169', 'densenet121', 'densenet201'):
        # The DenseNet constructor shares the architecture's name.
        model = getattr(models, arch)(pretrained=True)
        if args.src == 'imagenet':
            return model, None
        return model, ckpt_src_dir + f'/{arch}.pth'

    if arch == 'mobilenetv1':
        # MobileNetV1 comes from pytorchcv rather than torchvision.
        model = ptcv_get_model("mobilenet_w1", pretrained=False)
        return model, ckpt_src_dir + '/mobilenet_w1-0895-7e1d739f.pth'

    print(f"{args.net} is not defined.")
    exit()

def build_net_finetuned(args, net, ckpt_src):
    """Give ``net`` a target-sized classification head and load fine-tuned weights.

    Args:
        args: Object providing ``net`` (architecture name) and
            ``num_cls_tgt`` (number of target classes).
        net: The model to modify in place.
        ckpt_src: Path of a checkpoint whose ``'net'`` entry is a state dict
            for the modified model.

    Returns:
        The same ``net`` object, with the new head and the loaded weights.

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` when the
            checkpoint keys or shapes do not match the model.
    """
    # Load onto the CPU so checkpoints saved from a GPU also open on CPU-only
    # hosts; load_state_dict copies the values into the parameters of ``net``.
    ckpt = torch.load(ckpt_src, map_location='cpu')

    arch = args.net
    num_cls = args.num_cls_tgt

    if arch.startswith('resnet'):
        net.fc = torch.nn.Linear(2048, num_cls)
    elif arch.startswith('densenet'):
        # Classifier input width per variant; 1664 covers the remaining case.
        input_dim = {'densenet201': 1920, 'densenet121': 1024}.get(arch, 1664)
        net.classifier = torch.nn.Linear(input_dim, num_cls)
    elif arch.startswith('efficientnet'):
        input_dim = {'efficientnetb2': 1408, 'efficientnetb3': 1536}.get(arch, 1280)
        net.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2, inplace=True),
            torch.nn.Linear(input_dim, num_cls),
        )
    elif arch.startswith('mobilenet'):
        # Bug fix: the prefix was misspelled 'monilenet', so this branch never
        # ran. The input width is taken from an existing ``output`` layer when
        # the model has one, falling back to the previously hard-coded 2048.
        # NOTE(review): this always assigns ``net.output``; confirm that every
        # MobileNet variant used here keeps its head under that attribute.
        old_head = getattr(net, 'output', None)
        input_dim = old_head.in_features if isinstance(old_head, torch.nn.Linear) else 2048
        net.output = torch.nn.Linear(input_dim, num_cls)
    elif arch.startswith('vgg'):
        net.classifier = torch.nn.Sequential(
            torch.nn.Linear(512 * 7 * 7, 4096),
            torch.nn.ReLU(True),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(4096, 4096),
            torch.nn.ReLU(True),
            torch.nn.Dropout(p=0.5),
            torch.nn.Linear(4096, num_cls),
        )

    net.load_state_dict(ckpt['net'], strict=True)

    return net

def build_net_ssl(args, ckpt_src_dir):
    """Build the backbone named by ``args.net`` and wrap it for ``args.ssl``.

    Args:
        args: Object providing ``net`` (architecture name), ``ssl``
            (self-supervised method name), ``num_cls_src`` and, for DenseNet
            variants, ``src``.
        ckpt_src_dir: Directory prefix (string) that the checkpoint file name
            is appended to.

    Returns:
        ``(net, ckpt_src)``. When ``args.ssl`` names a supported method,
        ``net`` is the corresponding wrapper around the backbone with its last
        child module removed; otherwise it is the unwrapped network.
        ``ckpt_src`` is ``None`` for a DenseNet whose source is ``'imagenet'``.

    Raises:
        ValueError: If ``args.net`` is unknown, or if an SSL wrapper is
            requested for a backbone without a configured feature width
            (only ResNet-50/101/152 have one).
    """
    print('==> Building model..')
    arch = args.net

    # net name -> (torchvision constructor name, call with pretrained=True?,
    # checkpoint file name).
    tv_archs = {
        'resnet50': ('resnet50', True, '/resnet50-v2.pth'),
        'resnet101': ('resnet101', False, '/resnet101-v2.pth'),
        'resnet152': ('resnet152', False, '/resnet152-v2.pth'),
        'densenet169': ('densenet169', True, '/densenet169.pth'),
        'densenet121': ('densenet121', True, '/densenet121.pth'),
        'densenet201': ('densenet201', True, '/densenet201.pth'),
        'mobilenetv2': ('mobilenet_v2', False, '/mobilenet_v2.pth'),
        'mobilenetv3_large': ('mobilenet_v3_large', False, '/mobilenet_v3_large-8738ca79.pth'),
        'mobilenetv3_small': ('mobilenet_v3_small', False, '/mobilenet_v3_small-047dcff4.pth'),
        'swin_b': ('swin_b', False, '/swin_b.pth'),
        'swin_v2_b': ('swin_v2_b', False, '/swin_v2_b.pth'),
        'vit_b_16': ('vit_b_16', False, '/vit_b_16.pth'),
        # Bug fix: this branch called the non-existent
        # ``models.wide_resnet101_v2``; the constructor is ``wide_resnet101_2``.
        'wide_resnet101_2': ('wide_resnet101_2', False, '/wide_resnet101_2.pth'),
        'efficientnetb0': ('efficientnet_b0', True, '/efficientnet_b0_rwightman-3dd342df.pth'),
        'efficientnetb1': ('efficientnet_b1', False, '/efficientnet_b1_rwightman-533bc792.pth'),
        'efficientnetb2': ('efficientnet_b2', True, '/efficientnet_b2_rwightman-bcdf34b7.pth'),
        'efficientnetb3': ('efficientnet_b3', True, '/efficientnet_b3_rwightman-cf984f9c.pth'),
        'vgg16': ('vgg16', False, '/vgg16-397923af.pth'),
        'vgg19': ('vgg19', False, '/vgg19-dcbb9e9d.pth'),
    }

    if arch == 'mobilenetv1':
        # MobileNetV1 comes from pytorchcv rather than torchvision.
        net = ptcv_get_model("mobilenet_w1", pretrained=False)
        ckpt_src = ckpt_src_dir + '/mobilenet_w1-0895-7e1d739f.pth'
    elif arch in tv_archs:
        ctor_name, use_pretrained, ckpt_file = tv_archs[arch]
        ctor = getattr(models, ctor_name)
        net = ctor(pretrained=True) if use_pretrained else ctor()
        ckpt_src = ckpt_src_dir + ckpt_file
        if arch.startswith('densenet') and args.src == 'imagenet':
            # Same convention as build_net_imgnet_pretrain: an ImageNet
            # DenseNet needs no checkpoint. ``ckpt_src`` used to be left
            # unbound here, which crashed at the return statement.
            ckpt_src = None
    else:
        raise ValueError(f'{arch} is not defined.')

    # Feature width handed to the SSL wrappers; only configured for ResNets.
    input_dim = 2048 if arch in ('resnet50', 'resnet101', 'resnet152') else None

    # NOTE(review): this assigns ``fc`` on every architecture, including ones
    # that may not have an ``fc`` head of width 2048; confirm this is intended
    # for the non-ResNet backbones.
    net.fc = torch.nn.Linear(2048, args.num_cls_src)

    ssl_name = args.ssl
    if ssl_name in ['dcl', 'simclr']:
        wrapper = DCL
    elif ssl_name in ['barlowtwins', 'vicreg']:
        wrapper = BarlowTwins
    elif ssl_name == 'dino':
        wrapper = DINO
    elif ssl_name == 'fastsiam':
        wrapper = FastSiam
    elif ssl_name == 'moco':
        wrapper = MoCo
    elif ssl_name == 'nnclr':
        wrapper = NNCLR
    elif ssl_name == 'simsiam':
        wrapper = SimSiam
    elif ssl_name == 'smog':
        wrapper = SMoGModel
    elif ssl_name == 'swav':
        wrapper = SwaV
    elif ssl_name == 'swavq':
        wrapper = SwaV_Queue
    elif ssl_name == 'tico':
        wrapper = TiCo
    elif ssl_name == 'vicregl':
        wrapper = VICRegL
    else:
        # Unknown method: the plain network is returned unwrapped.
        wrapper = None

    if wrapper is not None:
        if input_dim is None:
            raise ValueError(
                f'No feature dimension is configured for {arch!r}; '
                f'cannot build the {ssl_name!r} model on top of it.'
            )
        # Drop the last child module (the ``fc`` assigned above).
        backbone = torch.nn.Sequential(*list(net.children())[:-1])
        net = wrapper(backbone, input_dim)

    return net, ckpt_src

def build_forward_func(args):
    """Return the forward function matching the architecture ``args.net``.

    Args:
        args: Object providing ``net`` (architecture name).

    Returns:
        The forward function selected by the name prefix: ``resnet``,
        ``densenet``, ``mobile``, ``swin``, ``efficientnet`` or ``vgg``.
        Within the ``mobile`` family, ``mobilenetv1`` and
        ``mobilenetv3_large`` have dedicated functions; every other name
        uses ``mobile_forward``.

    Raises:
        ValueError: If no forward function exists for ``args.net``
            (previously this surfaced as ``UnboundLocalError``).
    """
    arch = args.net

    if arch.startswith('resnet'):
        return res_forward
    if arch.startswith('densenet'):
        return dense_forward
    if arch.startswith('mobile'):
        if arch == 'mobilenetv1':
            return mobilenetv1_forward
        if arch == 'mobilenetv3_large':
            return mobilenetv3_forward
        return mobile_forward
    if arch.startswith('swin'):
        return swin_forward
    if arch.startswith('efficientnet'):
        return efficientnet_forward
    if arch.startswith('vgg'):
        return vgg_forward

    raise ValueError(f'No forward function is defined for network {arch!r}')

def build_estimater(args):
    """Return the estimator selected by ``args.metric``.

    Args:
        args: Object providing ``metric`` (metric name).

    Returns:
        ``nc_cls_density_Estimate_cb`` when ``args.metric`` is ``'FaCe'``.

    Raises:
        ValueError: For any other metric name (previously this surfaced as
            ``UnboundLocalError``).
    """
    # NOTE(review): ``correlation`` in this file only combines score pairs
    # when the metric is 'nc_cls_density_cb', while this function only knows
    # 'FaCe' -- confirm which name callers actually pass.
    if args.metric == 'FaCe':
        return nc_cls_density_Estimate_cb
    raise ValueError(f'No estimator is defined for metric {args.metric!r}')